#!/usr/bin/env python3

import itertools
import numpy as np
import rospy
import sys
import time

from transformations import euler_from_quaternion, quaternion_from_euler
from cv_bridge import *
from sensor_msgs.msg import *
from std_msgs.msg import *
from nav_msgs.msg import *
from geometry_msgs.msg import *

# Debug mode is switched on by passing --debug on the command line.
DEBUG = sys.argv.count("--debug") > 0

# OpenCV is only needed for the on-screen debug preview, so it is imported
# on demand.
if DEBUG:
    import cv2

class AutopilotNode(object):
    """ROS node that derives a steering wheel target from a semantic grid.

    The node subscribes to an ``OccupancyGrid`` on ``semantic_grid`` and to
    the current steering wheel angle, estimates an inverse turning radius
    from the grid cells whose value is 1, and publishes the resulting
    steering wheel target.  A colorized overlay of the grid is published on
    ``/autopilot_image`` (and shown in an OpenCV window in debug mode).
    """

    def __init__(self,
        ns_camera = "/camera_front",
        ns_vehicle = "/vehicle",
        node_name = "autopilot_node",
    ):
        """Initialize the ROS node, its parameters, publishers and subscribers.

        Blocks until the parameter ``<ns_camera>/semantic_categories`` exists.

        Args:
            ns_camera: namespace holding the ``semantic_categories`` parameter.
            ns_vehicle: namespace of the steering wheel topics.
            node_name: name passed to ``rospy.init_node``.
        """
        rospy.init_node(node_name)

        self.is_initialized = False
        self.local_map_flat = None
        self.local_map_yaw = None
        self.pixels_per_meter = None
        self.local_map_offset_x = None
        self.local_map_offset_y = None
        self.local_map_offset_px = None
        self.local_map_offset_py = None
        self.last_angle = 0.0
        self.plan_inverse_radius = 0.0
        self.plan_time = 0.0
        self.vehicle_width = 4.0 # m

        # cv2_to_imgmsg is a method of CvBridge, so keep one bridge around
        self.cv_bridge = CvBridge()

        # parameters

        self.param_steering_factor = rospy.get_param("~steering_factor", 2000.0) # degrees per inverse meter
            # e.g. 2000.0 means:
            # 2 degrees for 1000 m turning radius
            # 20 degrees for 100 m turning radius
            # 200 degrees for 10 m turning radius

        while not rospy.has_param("%s/semantic_categories" % ns_camera):
            rospy.loginfo_throttle(10, "waiting for categories")
            rospy.sleep(1)
        self.categories = rospy.get_param("%s/semantic_categories" % ns_camera)

        # publishers

        self.pub_steering_wheel_target = rospy.Publisher(
            "%s/steering_wheel_target" % ns_vehicle,
            Float32, queue_size = 1
        )

        self.pub_image = rospy.Publisher(
            "/autopilot_image",
            Image, queue_size = 1
        )

        # subscribers

        self.sub_grid = rospy.Subscriber(
            "semantic_grid",
            OccupancyGrid, self.on_grid, queue_size = 1
        )
        self.sub_steering_wheel_angle = rospy.Subscriber(
            "%s/steering_wheel_angle" % ns_vehicle,
            Float32, self.on_steering_wheel_angle, queue_size = 1
        )

        # color map for visualization
        # one color for each category plus the final [255,255,255] is a fake
        # category to display white-colored overlays
        # (each category color is reversed before use; the image is sent as bgr8)
        self.color_map = np.array(
            [list(reversed(c["color"])) for c in self.categories] + [[255,255,255]],
            dtype=np.uint8
        )

        self.is_initialized = True

    def on_steering_wheel_angle(self, msg):
        """Store the most recently reported steering wheel angle."""
        self.last_angle = msg.data

    def on_grid(self, msg):
        """Store the latest semantic grid together with its origin and yaw.

        Args:
            msg: ``OccupancyGrid``; its cells are reshaped to
                ``(info.height, info.width)`` as int8.
        """
        # np.fromstring(bytes(...)) is deprecated and fails when the cells
        # arrive as a sequence of ints containing negative values; convert
        # according to the actual container type instead.
        if isinstance(msg.data, (bytes, bytearray)):
            cells = np.frombuffer(msg.data, dtype=np.int8)
        else:
            cells = np.asarray(msg.data, dtype=np.int8)
        self.local_map_flat = cells.reshape((msg.info.height, msg.info.width))
        self.pixels_per_meter = 1/msg.info.resolution

        self.local_map_offset_x = msg.info.origin.position.x
        self.local_map_offset_y = msg.info.origin.position.y
        self.local_map_offset_px = int(msg.info.origin.position.x * self.pixels_per_meter)
        self.local_map_offset_py = int(msg.info.origin.position.y * self.pixels_per_meter)
        # NOTE(review): (w, x, y, z) order is presumably what the imported
        # `transformations` module expects -- confirm (tf.transformations
        # would use x, y, z, w).
        # local_map_yaw is assigned last because plan() uses it as the
        # "grid data is complete" flag.
        dummy_roll, dummy_pitch, self.local_map_yaw = euler_from_quaternion((
             msg.info.origin.orientation.w,
             msg.info.origin.orientation.x,
             msg.info.origin.orientation.y,
             msg.info.origin.orientation.z,
        ))

    def plan(self):
        """Update ``plan_inverse_radius`` from the current grid.

        Grid columns up to 50 m beyond ``local_map_offset_x`` are sampled in
        0.125 m steps.  For every column with a plausible spread of cells of
        value 1, the median row gives a lateral displacement, and the
        curvature of the circle through the origin and that point is
        ``2 * dy / (dx**2 + dy**2)``.  The result is the average of those
        curvatures, weighted by the cube of their index so later slices
        dominate.  Does nothing while no grid has been received.
        """
        if self.local_map_yaw is None:
            return

        grid = self.local_map_flat
        ppm = self.pixels_per_meter
        grid_width = grid.shape[1]
        origin_py = self.local_map_offset_y * ppm

        inverse_radiuses = []
        for slice_delta_x in np.arange(0, 50, 0.125):
            slice_px = int((self.local_map_offset_x + slice_delta_x) * ppm)
            if slice_px < 0:
                # left of the map; a negative index would silently wrap around
                continue
            if slice_px >= grid_width:
                # the remaining slices are beyond the end of the map
                break
            rows = np.flatnonzero(grid[:, slice_px] == 1)
            if rows.size <= 4:
                continue
            slice_mean = np.median(rows)
            slice_std = np.std(rows)
            # reject columns whose spread is too narrow or too wide
            if slice_std < 0.3 * ppm or slice_std > 2.0 * ppm:
                continue
            print("slice_delta_x:", slice_delta_x)
            print("mean", slice_mean, "std", slice_std)
            y_displacement = (slice_mean - origin_py) / ppm
            print("y_displacement", y_displacement)

            chord_squared = slice_delta_x ** 2 + y_displacement ** 2
            if chord_squared == 0:
                # target coincides with the origin: curvature is undefined
                # and a nan would poison the weighted average
                continue
            inverse_radiuses.append(2 * y_displacement / chord_squared)

        if len(inverse_radiuses) >= 5:
            self.plan_inverse_radius = np.average(inverse_radiuses, weights = np.arange(len(inverse_radiuses))**3 )

        self.plan_time = time.time()

    def _arc(self, length, count, inverse_radius):
        """Return x/y offsets (m) of an arc of given curvature starting at the current yaw."""
        steps = np.diff(np.linspace(0.0, length, count))
        angles = self.local_map_yaw + np.arange(count) * steps[0] * inverse_radius
        xs = np.cumsum(steps * np.cos(angles[1:]))
        ys = np.cumsum(steps * np.sin(angles[1:]))
        return xs, ys

    def _to_pixels(self, xs, ys):
        """Convert offsets (m) relative to the map origin offset into signed pixel indexes."""
        ppm = self.pixels_per_meter
        # signed integers: an unsigned cast would wrap negative coordinates
        # back into the image
        px = np.floor(xs * ppm + self.local_map_offset_x * ppm).astype(np.int64)
        py = np.floor(ys * ppm + self.local_map_offset_y * ppm).astype(np.int64)
        return px, py

    def _mark_points(self, image, xs, ys, value):
        """Set single pixels of ``image`` at the given offsets (m), skipping points outside the image."""
        px, py = self._to_pixels(xs, ys)
        height, width = image.shape
        inside = (px >= 0) & (px < width) & (py >= 0) & (py < height)
        image[py[inside], px[inside]] = value

    def _render(self):
        """Return a copy of the grid with the overlays drawn as category ``len(self.categories)``."""
        image = self.local_map_flat.copy()
        height, width = image.shape
        marker = len(self.categories)

        # plot current vehicle direction
        scale = np.linspace(0.0, 200.0, 100)
        self._mark_points(
            image,
            scale * np.cos(self.local_map_yaw),
            scale * np.sin(self.local_map_yaw),
            marker,
        )

        # plot steering target (planned arc) as small squares that are only
        # drawn when they lie completely inside the image
        xs, ys = self._arc(100.0, 25, self.plan_inverse_radius)
        centers_x, centers_y = self._to_pixels(xs, ys)
        for center_x, center_y in zip(centers_x.tolist(), centers_y.tolist()):
            if center_y - 2 > 0 and \
                center_x - 2 > 0 and \
                center_y + 2 < height and \
                center_x + 2 < width:
                image[
                    center_y - 2 : center_y + 2,
                    center_x - 2 : center_x + 2
                ] = marker

        # plot the arc corresponding to the current steering wheel angle
        xs, ys = self._arc(100.0, 50, self.last_angle / self.param_steering_factor)
        self._mark_points(image, xs, ys, marker)

        # plot current location (bounds clamped so that a position near the
        # border cannot turn into a wrapped negative slice)
        px = int(self.local_map_offset_x * self.pixels_per_meter)
        py = int(self.local_map_offset_y * self.pixels_per_meter)
        image[
            max(py - 2, 0) : max(py + 3, 0),
            max(px - 2, 0) : max(px + 3, 0)
        ] = marker

        return image

    def spin(self):
        """Run the 50 Hz main loop until ROS shuts down.

        Each cycle replans, publishes the steering wheel target
        (``plan_inverse_radius * steering_factor``) and, when somebody is
        listening or debug mode is on, publishes/shows the overlay image.
        """
        rate = rospy.Rate(50)
        while not rospy.is_shutdown():
            rate.sleep()

            if not self.is_initialized:
                continue

            if self.local_map_flat is None:
                continue

            self.plan()

            if time.time() - self.plan_time > 0.5:
                rospy.logwarn("planner did not update")
                continue

            steering_wheel_target = self.plan_inverse_radius * self.param_steering_factor
            m = Float32()
            m.data = steering_wheel_target
            self.pub_steering_wheel_target.publish(m)

            # rendering is only worth the effort when somebody consumes it
            publish_image = self.pub_image.get_num_connections() > 0
            if not (publish_image or DEBUG):
                continue

            # colorize map (rows flipped for display)
            display_image_color = self.color_map[self._render()[::-1, :]]

            if publish_image:
                msg_image = self.cv_bridge.cv2_to_imgmsg(display_image_color, encoding = "bgr8")
                self.pub_image.publish(msg_image)

            if DEBUG:
                cv2.imshow("autopilot", display_image_color)
                cv2.waitKey(1)

if __name__ == "__main__":
    try:
        node = AutopilotNode()
        node.spin()
    except rospy.ROSInterruptException:
        # rospy.sleep() / Rate.sleep() raise this when the node is shut down
        # (e.g. Ctrl-C); exit quietly instead of printing a traceback.
        pass

